{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "cycle_gan.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NNM951wGtJnD"
   },
   "source": [
    "Take a look at the [repository](https://github.com/mit-han-lab/gan-compression) for more information"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wEjn7TTwteo7"
   },
   "source": [
    "# Install"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "ILBLDRBptlJ7",
    "colab": {}
   },
   "source": [
    "!git clone https://github.com/mit-han-lab/gan-compression.git"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "ohyUOwJXuCzO",
    "colab": {}
   },
   "source": [
    "import os\n",
    "os.chdir('gan-compression')\n",
    "print(os.getcwd())"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "9MAafxEuuM3S",
    "colab": {}
   },
   "source": [
    "!pip install -r requirements.txt\n",
    "!pip install --upgrade git+https://github.com/mit-han-lab/torchprofile.git"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "lGIBwW-rFTOv",
    "colab": {}
   },
   "source": [
    "import pickle\n",
    "import time\n",
    "import tqdm\n",
    "\n",
    "import numpy as np\n",
    "import torch"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Jrnb8vrqwDS4"
   },
   "source": [
    "# Datasets\n",
    "Download the dataset horse2zebra for testing."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "yTP8y6_szmZk",
    "colab": {}
   },
   "source": [
    "!bash datasets/download_cyclegan_dataset.sh horse2zebra"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "huMEiWB7vQeU"
   },
   "source": [
    "# Pretrained Models\n",
    "Download the original model and our compressed of horse2zebra dataset."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "-ttwzwbS0L5i",
    "colab": {}
   },
   "source": [
    "!python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full\n",
    "!python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed\n",
    "print('Download the pretrained models successfully!!!')"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cPvL2-Sbxwep",
    "colab_type": "text"
   },
   "source": [
    "# Test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rxUAG0iCxwer",
    "colab_type": "text",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Get the options for the test."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "f7zb8niYxwet",
    "colab_type": "code",
    "colab": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "!python get_test_opt.py --dataroot database/horse2zebra/valA \\\n",
    "--dataset_mode single \\\n",
    "--results_dir results-pretrained/cycle_gan/horse2zebra/full \\\n",
    "--ngf 64 --netG resnet_9blocks \\\n",
    "--restore_G_path pretrained/cycle_gan/horse2zebra/full/latest_net_G.pth \\\n",
    "--need_profile --real_stat_path real_stat/horse2zebra_B.npz \\\n",
    "--num_test 0\n",
    "\n",
    "!python get_test_opt.py --dataroot database/horse2zebra/valA \\\n",
    "--dataset_mode single \\\n",
    "--results_dir results-pretrained/cycle_gan/horse2zebra/compressed \\\n",
    "--config_str 16_16_32_16_32_32_16_16 \\\n",
    "--restore_G_path pretrained/cycle_gan/horse2zebra/compressed/latest_net_G.pth \\\n",
    "--need_profile --real_stat_path real_stat/horse2zebra_B.npz \\\n",
    "--num_test 0"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NmgoZT_txwez",
    "colab_type": "text",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Get the dataloader for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "yBbqYtXG1ok1",
    "colab": {},
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "from data import create_dataloader\n",
    "with open('opts/opt_full.pkl', 'rb') as f:\n",
    "    opt = pickle.load(f)\n",
    "dataloader = create_dataloader(opt)"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LVREgogaxwe7",
    "colab_type": "text"
   },
   "source": [
    "Get the original full model."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "cClmEVqExwe8",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "from models import create_model\n",
    "\n",
    "with open('opts/opt_full.pkl', 'rb') as f:\n",
    "    opt = pickle.load(f)\n",
    "model_full = create_model(opt, verbose=False)\n",
    "model_full.setup(opt, verbose=False)"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "82pRFmnsxwfD",
    "colab_type": "text"
   },
   "source": [
    "Get our compressed model."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "MN1r7Z7FxwfE",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "with open('opts/opt_compressed.pkl', 'rb') as f:\n",
    "    opt = pickle.load(f)\n",
    "model_ours = create_model(opt, verbose=False)\n",
    "model_ours.setup(opt, verbose=False)"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k_h4jWpSd8ao",
    "colab_type": "text"
   },
   "source": [
    "Get visual results of the models."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "7ihatGcxeAxr",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from utils.util import save_image, tensor2im\n",
    "\n",
    "transform_list = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
    "transform = transforms.Compose(transform_list)\n",
    "print('             Input                        Original                      Compressed')\n",
    "imgs_dir = 'imgs/horse2zebra'\n",
    "files = os.listdir(imgs_dir)\n",
    "for file in files:\n",
    "    if not file.endswith('.png'):\n",
    "        continue\n",
    "    base = file.split('.')[0]\n",
    "    path = os.path.join(imgs_dir, file)\n",
    "    A_img = Image.open(path).convert('RGB')\n",
    "    input = transform(A_img).to('cuda:0')\n",
    "    input = input.reshape([1, 3, 256, 256])\n",
    "    output_full = model_full.netG(input).cpu()\n",
    "    output_ours = model_ours.netG(input).cpu()\n",
    "    B_full = tensor2im(output_full)\n",
    "    B_ours = tensor2im(output_ours)\n",
    "    save_image(B_full, 'output/full/%s.png' % base, create_dir=True)\n",
    "    save_image(B_ours, 'output/ours/%s.png' % base, create_dir=True)\n",
    "    p1 = Image.open(path)\n",
    "    p2 = Image.open('output/full/%s.png' % base)\n",
    "    p3 = Image.open('output/ours/%s.png' % base)\n",
    "    all = Image.new('RGB', (256*3, 256))\n",
    "    all.paste(p1, (0, 0))\n",
    "    all.paste(p2, (256, 0))\n",
    "    all.paste(p3, (256*2, 0))\n",
    "    display(all)"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r6GVJNACxwfJ",
    "colab_type": "text"
   },
   "source": [
    "Profile and evaluate the two models."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "0gaviBh3dIxo",
    "colab": {}
   },
   "source": [
    "def evaluate(model):\n",
    "    fakes = []\n",
    "    for i, data in enumerate(tqdm.tqdm(dataloader)):\n",
    "        model.set_input(data)\n",
    "        if i == 0:\n",
    "            macs, params = model.profile(verbose=False)\n",
    "        model.test()\n",
    "        visuals = model.get_current_visuals()\n",
    "        generated = visuals['fake_B'].cpu()\n",
    "        fakes.append(generated)\n",
    "    # Test the latency of the model\n",
    "    for i in range(100):\n",
    "        model.test()\n",
    "        torch.cuda.synchronize()\n",
    "    start_time = time.time()\n",
    "    for i in range(100):\n",
    "        model.test()\n",
    "        torch.cuda.synchronize()\n",
    "    cost = time.time()-start_time\n",
    "    latency = cost/100\n",
    "    return macs, params, latency, fakes\n",
    "\n",
    "macs_full, params_full, latency_full, fakes_full = evaluate(model_full)\n",
    "macs_ours, params_ours, latency_ours, fakes_ours = evaluate(model_ours)\n",
    "print('\\n\\n')\n",
    "print('Full: %.3fG MACs\\t%.3fM Params\\t%.5fs Latency' % (macs_full/1e9, params_full/1e6, latency_full))\n",
    "print('Ours: %.3fG MACs\\t%.3fM Params\\t%.5fs Latency' % (macs_ours/1e9, params_ours/1e6, latency_ours))"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9D87Wb7Ee-AH",
    "colab_type": "text"
   },
   "source": [
    "Download the statistical information of the ground-truth images for FID computation."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "1vZ5xBraelU6",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "!bash ./datasets/download_real_stat.sh horse2zebra B\n",
    "npz = np.load('real_stat/horse2zebra_B.npz')\n",
    "print('Loaded the statistical information of the ground-truth images.')"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ll_zuhKHe-4E",
    "colab_type": "text"
   },
   "source": [
    "Define the Inception Model for FID computation. (**This may take about 3mins.**)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "eDKHUq8CetHi",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "from metric.inception import InceptionV3\n",
    "block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]\n",
    "inception_model = InceptionV3([block_idx])\n",
    "inception_model.to('cuda:0')\n",
    "inception_model.eval()\n",
    "print('Successfully built an InceptionV3 model.')"
   ],
   "execution_count": 0,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iBghJHoCdIxk"
   },
   "source": [
    "Computing the FID of the two models."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "GVv2X-VidIxO",
    "colab": {}
   },
   "source": [
    "from metric import get_fid\n",
    "fid_full = get_fid(fakes_full, inception_model, npz, 'cuda:0', 16)\n",
    "fid_ours = get_fid(fakes_ours, inception_model, npz, 'cuda:0', 16)\n",
    "print('\\n\\n')\n",
    "print('Full FID: %.3f' % fid_full)\n",
    "print('Ours FID: %.3f' % fid_ours)"
   ],
   "execution_count": 0,
   "outputs": []
  }
 ]
}